import logging
from typing import Iterable, List, Optional

from pydantic import Field, SecretStr, model_validator

from datahub.configuration.common import ConfigModel
from datahub.configuration.source_common import (
    DatasetSourceConfigMixin,
    LowerCaseDatasetUrnConfigMixin,
)
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
    SupportStatus,
    capability,
    config_class,
    platform_name,
    support_status,
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceCapability
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.common.subtypes import SourceCapabilityModifier
from datahub.ingestion.source.data_lake_common.config import PathSpecsConfigMixin
from datahub.ingestion.source.data_lake_common.data_lake_utils import PLATFORM_GCS
from datahub.ingestion.source.data_lake_common.object_store import (
    create_object_store_adapter,
)
from datahub.ingestion.source.data_lake_common.path_spec import PathSpec, is_gcs_uri
from datahub.ingestion.source.s3.config import DataLakeSourceConfig
from datahub.ingestion.source.s3.report import DataLakeSourceReport
from datahub.ingestion.source.s3.source import S3Source
from datahub.ingestion.source.state.stale_entity_removal_handler import (
    StaleEntityRemovalHandler,
    StatefulStaleMetadataRemovalConfig,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
    StatefulIngestionConfigBase,
    StatefulIngestionSourceBase,
)

logger: logging.Logger = logging.getLogger(__name__)

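# GCS exposes an S3-compatible XML API at this endpoint; the wrapped S3 source
# talks to it using the configured HMAC credentials.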
GCS_ENDPOINT_URL = "https://storage.googleapis.com"


class HMACKey(ConfigModel):
    hmac_access_id: str = Field(description="HMAC key access ID")
    hmac_access_secret: SecretStr = Field(description="HMAC key secret")


class GCSSourceConfig(
    StatefulIngestionConfigBase,
    DatasetSourceConfigMixin,
    PathSpecsConfigMixin,
    LowerCaseDatasetUrnConfigMixin,
):
    credential: HMACKey = Field(
        description="Google cloud storage [HMAC keys](https://cloud.google.com/storage/docs/authentication/hmackeys)",
    )

    max_rows: int = Field(
        default=100,
        description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
    )

    number_of_files_to_sample: int = Field(
        default=100,
        description="Number of files to sample for schema inference. Ignored if sample_files is set to False in the path spec.",
    )

    stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None

    @model_validator(mode="after")
    def check_path_specs_and_infer_platform(self) -> "GCSSourceConfig":
        if len(self.path_specs) == 0:
            raise ValueError("path_specs must not be empty")

        # Check that all path specs have the gs:// prefix.
        if any(not is_gcs_uri(path_spec.include) for path_spec in self.path_specs):
            raise ValueError("All path_spec.include values must start with gs://")

        return self


class GCSSourceReport(DataLakeSourceReport):
    pass


@platform_name("Google Cloud Storage", id=PLATFORM_GCS)
@config_class(GCSSourceConfig)
@support_status(SupportStatus.INCUBATING)
@capability(
    SourceCapability.CONTAINERS,
    "Enabled by default",
    subtype_modifier=[
        SourceCapabilityModifier.GCS_BUCKET,
        SourceCapabilityModifier.FOLDER,
    ],
)
@capability(SourceCapability.SCHEMA_METADATA, "Enabled by default")
@capability(SourceCapability.DATA_PROFILING, "Not supported", supported=False)
class GCSSource(StatefulIngestionSourceBase):
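    """
    Ingests metadata from Google Cloud Storage buckets.

    Internally, this source builds an equivalent S3 source configured against
    GCS's S3-interoperable XML API using the supplied HMAC credentials, then
    applies a GCS adapter so that the emitted platform, URNs, and containers
    refer to GCS rather than S3.

    Illustrative recipe (a minimal sketch; the source type is assumed to match
    the platform id, "gcs", and bucket/key names are placeholders):

        source:
          type: gcs
          config:
            path_specs:
              - include: "gs://my-bucket/data/*.parquet"
            credential:
              hmac_access_id: "<access id>"
              hmac_access_secret: "<secret>"
    """
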
    def __init__(self, config: GCSSourceConfig, ctx: PipelineContext):
        super().__init__(config, ctx)
        self.config = config
        self.report = GCSSourceReport()
        self.platform: str = PLATFORM_GCS
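        # Delegate listing, schema inference, and workunit generation to an S3
        # source pointed at GCS's S3-compatible endpoint.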
        self.s3_source = self.create_equivalent_s3_source(ctx)

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> "GCSSource":
        config = GCSSourceConfig.model_validate(config_dict)
        return cls(config, ctx)

    def create_equivalent_s3_config(self) -> DataLakeSourceConfig:
        s3_path_specs = self.create_equivalent_s3_path_specs()

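        # Map the GCS HMAC key onto AWS-style credentials and point the AWS
        # client at the GCS interoperability endpoint.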
        s3_config = DataLakeSourceConfig(
            path_specs=s3_path_specs,
            aws_config=AwsConnectionConfig(
                aws_endpoint_url=GCS_ENDPOINT_URL,
                aws_access_key_id=self.config.credential.hmac_access_id,
                aws_secret_access_key=self.config.credential.hmac_access_secret.get_secret_value(),
                aws_region="auto",
            ),
            env=self.config.env,
            convert_urns_to_lowercase=self.config.convert_urns_to_lowercase,
            max_rows=self.config.max_rows,
            number_of_files_to_sample=self.config.number_of_files_to_sample,
            platform=PLATFORM_GCS,  # Ensure GCS platform is used for correct container subtypes
            platform_instance=self.config.platform_instance,
        )
        return s3_config

    def create_equivalent_s3_path_specs(self) -> List[PathSpec]:
        s3_path_specs = []
        for path_spec in self.config.path_specs:
            # PathSpec modifies the passed-in include to add /** to the end if
            # autodetecting partitions. Remove that, otherwise creating a new
            # PathSpec will complain.
            # TODO: this should be handled inside PathSpec, which probably shouldn't
            # modify its input.
            include = path_spec.include
            if include.endswith("{table}/**") and not path_spec.allow_double_stars:
                include = include.removesuffix("**")

            s3_path_specs.append(
                PathSpec(
                    include=include.replace("gs://", "s3://"),
                    exclude=(
                        [exc.replace("gs://", "s3://") for exc in path_spec.exclude]
                        if path_spec.exclude
                        else None
                    ),
                    file_types=path_spec.file_types,
                    default_extension=path_spec.default_extension,
                    table_name=path_spec.table_name,
                    enable_compression=path_spec.enable_compression,
                    sample_files=path_spec.sample_files,
                    allow_double_stars=path_spec.allow_double_stars,
                    autodetect_partitions=path_spec.autodetect_partitions,
                    include_hidden_folders=path_spec.include_hidden_folders,
                    tables_filter_pattern=path_spec.tables_filter_pattern,
                    traversal_method=path_spec.traversal_method,
                )
            )

        return s3_path_specs

    def create_equivalent_s3_source(self, ctx: PipelineContext) -> S3Source:
        config = self.create_equivalent_s3_config()
        # Create a new context for S3 source without graph to avoid duplicate checkpointer registration
        s3_ctx = PipelineContext(run_id=ctx.run_id, pipeline_name=ctx.pipeline_name)
        s3_source = S3Source(config, s3_ctx)
        return self.s3_source_overrides(s3_source)

    def s3_source_overrides(self, source: S3Source) -> S3Source:
        """
        Override S3Source methods with GCS-specific implementations using the adapter pattern.

        This method customizes the S3Source instance to behave like a GCS source
        by applying a GCS-specific adapter that overrides the relevant functionality.

        Args:
            source: The S3Source instance to customize

        Returns:
            The modified S3Source instance with GCS behavior
        """
        # Create a GCS adapter that swaps in GCS-specific platform, URN, and
        # container handling on the S3 source
        adapter = create_object_store_adapter("gcs")

        # Apply all customizations to the source
        return adapter.apply_customizations(source)

    def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
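        # With stateful ingestion enabled, stale entity removal soft-deletes
        # datasets that have disappeared from GCS since the previous run.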
        return [
            *super().get_workunit_processors(),
            StaleEntityRemovalHandler.create(
                self, self.config, self.ctx
            ).workunit_processor,
        ]

    def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
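        # All work units come from the wrapped S3 source, which has been
        # customized above to emit GCS-flavored metadata.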
        return self.s3_source.get_workunits_internal()

    def get_report(self) -> GCSSourceReport:
        return self.report
